{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training a neural network with DALI and Flax\n",
    "\n",
    "This simple example shows how to train a neural network implemented in Flax using DALI pipelines for data loading. If you want to learn more about training neural networks with Flax, see the [Flax Getting Started](https://flax.readthedocs.io/en/latest/getting_started.html) example.\n",
    "\n",
    "The DALI setup is very similar to the [training example with pure JAX](jax-basic_example.ipynb). The pipeline decodes each MNIST image and flattens it into a one-dimensional vector, which is the input expected by the dense layers of the Flax model defined below. If you are not familiar with using DALI with JAX, you can learn more in the [DALI and JAX Getting Started](jax-getting_started.ipynb) example.\n",
    "\n",
    "We use MNIST in Caffe2 format from [DALI_extra](https://github.com/NVIDIA/DALI_extra)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-28T07:43:41.850101Z",
     "iopub.status.busy": "2023-07-28T07:43:41.849672Z",
     "iopub.status.idle": "2023-07-28T07:43:41.853520Z",
     "shell.execute_reply": "2023-07-28T07:43:41.852990Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "training_data_path = os.path.join(\n",
    "    os.environ[\"DALI_EXTRA_PATH\"], \"db/MNIST/training/\"\n",
    ")\n",
    "validation_data_path = os.path.join(\n",
    "    os.environ[\"DALI_EXTRA_PATH\"], \"db/MNIST/testing/\"\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first step is to create an iterator definition function that will later be used to create instances of DALI iterators. It defines all the preprocessing steps. In this simple example we use `fn.readers.caffe2` to read data in Caffe2 format, `fn.decoders.image` to decode the images, `fn.crop_mirror_normalize` to normalize them, and `fn.reshape` to flatten the output tensors. We also move the labels from CPU to GPU memory with `labels.gpu()` and, for training only, apply one-hot encoding to them with `fn.one_hot`.\n",
    "\n",
    "This example focuses on how to use a DALI pipeline with JAX. For more information on DALI iterators, see [DALI and JAX getting started](jax-getting_started.ipynb) and the [pipeline documentation](../../../pipeline.rst)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-28T07:43:41.855441Z",
     "iopub.status.busy": "2023-07-28T07:43:41.855301Z",
     "iopub.status.idle": "2023-07-28T07:43:41.986406Z",
     "shell.execute_reply": "2023-07-28T07:43:41.985500Z"
    }
   },
   "outputs": [],
   "source": [
    "import nvidia.dali.fn as fn\n",
    "import nvidia.dali.types as types\n",
    "\n",
    "from nvidia.dali.plugin.jax import data_iterator\n",
    "\n",
    "\n",
    "batch_size = 50\n",
    "image_size = 28\n",
    "num_classes = 10\n",
    "\n",
    "\n",
    "@data_iterator(\n",
    "    output_map=[\"images\", \"labels\"], reader_name=\"mnist_caffe2_reader\"\n",
    ")\n",
    "def mnist_iterator(data_path, is_training):\n",
    "    jpegs, labels = fn.readers.caffe2(\n",
    "        path=data_path, random_shuffle=is_training, name=\"mnist_caffe2_reader\"\n",
    "    )\n",
    "    images = fn.decoders.image(jpegs, device=\"mixed\", output_type=types.GRAY)\n",
    "    images = fn.crop_mirror_normalize(images, dtype=types.FLOAT, std=[255.0])\n",
    "    images = fn.reshape(images, shape=[-1])  # Flatten the output image\n",
    "\n",
    "    labels = labels.gpu()\n",
    "\n",
    "    if is_training:\n",
    "        labels = fn.one_hot(labels, num_classes=num_classes)\n",
    "\n",
    "    return images, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the iterator definition function we can now create DALI iterators."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-28T07:43:42.109692Z",
     "iopub.status.busy": "2023-07-28T07:43:42.109536Z",
     "iopub.status.idle": "2023-07-28T07:43:43.557199Z",
     "shell.execute_reply": "2023-07-28T07:43:43.556686Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating iterators\n",
      "<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7fdc240f4e50>\n",
      "<nvidia.dali.plugin.jax.iterator.DALIGenericIterator object at 0x7fdc1c78e020>\n",
      "Number of batches in training iterator = 1200\n",
      "Number of batches in validation iterator = 200\n"
     ]
    }
   ],
   "source": [
    "print(\"Creating iterators\")\n",
    "training_iterator = mnist_iterator(\n",
    "    data_path=training_data_path, is_training=True, batch_size=batch_size\n",
    ")\n",
    "validation_iterator = mnist_iterator(\n",
    "    data_path=validation_data_path, is_training=False, batch_size=batch_size\n",
    ")\n",
    "\n",
    "print(training_iterator)\n",
    "print(validation_iterator)\n",
    "\n",
    "print(f\"Number of batches in training iterator = {len(training_iterator)}\")\n",
    "print(f\"Number of batches in validation iterator = {len(validation_iterator)}\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the setup above, DALI iterators are ready for training.\n",
    "\n",
    "Now we need to set up the model and training utilities. The goal of this notebook is not to explain Flax concepts; we want to show how to train models implemented in Flax with DALI as the data loading and preprocessing library. We use standard Flax tools to define a simple neural network (a stack of fully connected layers, even though the class is named `CNN`). We also define functions to create an instance of this network, run one training step on it, and calculate the accuracy of the model at the end of each epoch.\n",
    "\n",
    "If you want to learn more about Flax and better understand the code below, see the [Flax Documentation](https://flax.readthedocs.io/en/latest/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-28T07:43:43.559575Z",
     "iopub.status.busy": "2023-07-28T07:43:43.559420Z",
     "iopub.status.idle": "2023-07-28T07:43:43.618221Z",
     "shell.execute_reply": "2023-07-28T07:43:43.617532Z"
    }
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "from flax import linen as nn\n",
    "from flax.training import train_state\n",
    "\n",
    "import optax\n",
    "\n",
    "\n",
    "class CNN(nn.Module):\n",
    "    @nn.compact\n",
    "    def __call__(self, x):\n",
    "        x = nn.Dense(features=784)(x)\n",
    "        x = nn.relu(x)\n",
    "        x = nn.Dense(features=1024)(x)\n",
    "        x = nn.relu(x)\n",
    "        x = nn.Dense(features=1024)(x)\n",
    "        x = nn.relu(x)\n",
    "        x = nn.Dense(features=10)(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "def create_model_state(rng, learning_rate, momentum):\n",
    "    cnn = CNN()\n",
    "    params = cnn.init(rng, jnp.ones([784]))[\"params\"]\n",
    "    tx = optax.sgd(learning_rate, momentum)\n",
    "    return train_state.TrainState.create(\n",
    "        apply_fn=cnn.apply, params=params, tx=tx\n",
    "    )\n",
    "\n",
    "\n",
    "@jax.jit\n",
    "def train_step(model_state, batch):\n",
    "    def loss_fn(params):\n",
    "        logits = model_state.apply_fn({\"params\": params}, batch[\"images\"])\n",
    "        loss = optax.softmax_cross_entropy(\n",
    "            logits=logits, labels=batch[\"labels\"]\n",
    "        ).mean()\n",
    "        return loss\n",
    "\n",
    "    grad_fn = jax.grad(loss_fn)\n",
    "    grads = grad_fn(model_state.params)\n",
    "    model_state = model_state.apply_gradients(grads=grads)\n",
    "    return model_state\n",
    "\n",
    "\n",
    "def accuracy(model_state, iterator):\n",
    "    correct_predictions = 0\n",
    "    for batch in iterator:\n",
    "        logits = model_state.apply_fn(\n",
    "            {\"params\": model_state.params}, batch[\"images\"]\n",
    "        )\n",
    "        correct_predictions = correct_predictions + jnp.sum(\n",
    "            batch[\"labels\"].ravel() == jnp.argmax(logits, axis=-1)\n",
    "        )\n",
    "\n",
    "    return correct_predictions / iterator.size"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the utilities defined above, we can create an instance of the model we want to train."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2023-07-28T07:43:43.622376Z",
     "iopub.status.busy": "2023-07-28T07:43:43.621205Z",
     "iopub.status.idle": "2023-07-28T07:43:58.016073Z",
     "shell.execute_reply": "2023-07-28T07:43:58.015333Z"
    }
   },
   "outputs": [],
   "source": [
    "rng = jax.random.PRNGKey(0)\n",
    "rng, init_rng = jax.random.split(rng)\n",
    "\n",
    "learning_rate = 0.1\n",
    "momentum = 0.9\n",
    "\n",
    "model_state = create_model_state(init_rng, learning_rate, momentum)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "At this point, everything is ready to run the training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Starting training\n",
      "Epoch 0\n",
      "Accuracy = 0.9551000595092773\n",
      "Epoch 1\n",
      "Accuracy = 0.9691000580787659\n",
      "Epoch 2\n",
      "Accuracy = 0.9738000631332397\n",
      "Epoch 3\n",
      "Accuracy = 0.9622000455856323\n",
      "Epoch 4\n",
      "Accuracy = 0.9604000449180603\n"
     ]
    }
   ],
   "source": [
    "print(\"Starting training\")\n",
    "\n",
    "num_epochs = 5\n",
    "for epoch in range(num_epochs):\n",
    "    print(f\"Epoch {epoch}\")\n",
    "    for batch in training_iterator:\n",
    "        model_state = train_step(model_state, batch)\n",
    "\n",
    "    acc = accuracy(model_state, validation_iterator)\n",
    "    print(f\"Accuracy = {acc}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multiple GPUs with DALI and Flax\n",
    "\n",
    "This section shows how to extend the example above to use multiple GPUs.\n",
    "\n",
    "Again, we start by creating an iterator definition function. It is a slightly modified version of the function we have seen before.\n",
    "\n",
    "Note the new arguments passed to `fn.readers.caffe2`, `num_shards` and `shard_id`. They control sharding:\n",
    " - `num_shards` sets the total number of shards.\n",
    " - `shard_id` tells the pipeline which shard of the dataset it is responsible for.\n",
    "\n",
    "We add the `devices` argument to the decorator to specify which devices we want to use. Here we use all GPUs available to JAX on the machine. Note that `num_shards` and `shard_id` are not passed explicitly when the iterator is created below; the decorator supplies them for each device."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 200\n",
    "image_size = 28\n",
    "num_classes = 10\n",
    "\n",
    "\n",
    "@data_iterator(\n",
    "    output_map=[\"images\", \"labels\"],\n",
    "    reader_name=\"mnist_caffe2_reader\",\n",
    "    devices=jax.devices(),\n",
    ")\n",
    "def mnist_sharded_iterator(data_path, is_training, num_shards, shard_id):\n",
    "    jpegs, labels = fn.readers.caffe2(\n",
    "        path=data_path,\n",
    "        random_shuffle=is_training,\n",
    "        name=\"mnist_caffe2_reader\",\n",
    "        num_shards=num_shards,\n",
    "        shard_id=shard_id,\n",
    "    )\n",
    "    images = fn.decoders.image(jpegs, device=\"mixed\", output_type=types.GRAY)\n",
    "    images = fn.crop_mirror_normalize(\n",
    "        images, dtype=types.FLOAT, std=[255.0], output_layout=\"CHW\"\n",
    "    )\n",
    "    images = fn.reshape(images, shape=[-1])  # Flatten the output image\n",
    "\n",
    "    labels = labels.gpu()\n",
    "\n",
    "    if is_training:\n",
    "        labels = fn.one_hot(labels, num_classes=num_classes)\n",
    "\n",
    "    return images, labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the iterator definition function we can now create a DALI iterator for training on multiple GPUs. This iterator returns outputs compatible with JAX functions transformed with `jax.pmap`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating training iterator\n",
      "Number of batches in training iterator = 300\n"
     ]
    }
   ],
   "source": [
    "print(\"Creating training iterator\")\n",
    "training_iterator = mnist_sharded_iterator(\n",
    "    data_path=training_data_path, is_training=True, batch_size=batch_size\n",
    ")\n",
    "\n",
    "print(f\"Number of batches in training iterator = {len(training_iterator)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For simplicity, we will run validation on one GPU. We can reuse the validation iterator from the single GPU example. The only difference is that we need to pull the model to the same GPU. In a real-life scenario this might be costly, but it is sufficient for this toy educational example.\n",
    "\n",
    "For the model to be compatible with pmap-style multiple GPU training we need to replicate it. If you want to learn more about training on multiple GPUs with `pmap`, see [Parallel Evaluation in JAX](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html) from the JAX documentation and [Ensembling on multiple devices](https://flax.readthedocs.io/en/latest/guides/ensembling.html#ensembling-on-multiple-devices) from the Flax documentation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = jax.random.PRNGKey(0)\n",
    "rng, init_rng = jax.random.split(rng)\n",
    "\n",
    "learning_rate = 0.1\n",
    "momentum = 0.9\n",
    "\n",
    "model_state = jax.pmap(create_model_state, static_broadcasted_argnums=(1, 2))(\n",
    "    jax.random.split(init_rng, jax.device_count()), learning_rate, momentum\n",
    ")"
   ]
  },
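  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since we want to run validation on a single GPU, we extract only one replica of the model with `flax.jax_utils.unreplicate` and pass it to the `accuracy` function. Note that `train_step` contains no cross-device collective such as `jax.lax.pmean`, so each replica is trained independently on its own shard and the reported accuracy is that of the first replica.\n",
    "\n",
    "Now we are ready to train the Flax model on multiple GPUs with DALI as the data source."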
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0\n",
      "Accuracy = 0.9509000182151794\n",
      "Epoch 1\n",
      "Accuracy = 0.9643000364303589\n",
      "Epoch 2\n",
      "Accuracy = 0.9724000692367554\n",
      "Epoch 3\n",
      "Accuracy = 0.9701000452041626\n",
      "Epoch 4\n",
      "Accuracy = 0.9758000373840332\n"
     ]
    }
   ],
   "source": [
    "import flax\n",
    "\n",
    "parallel_train_step = jax.pmap(train_step)\n",
    "\n",
    "num_epochs = 5\n",
    "for epoch in range(num_epochs):\n",
    "    print(f\"Epoch {epoch}\")\n",
    "    for batch in training_iterator:\n",
    "        model_state = parallel_train_step(model_state, batch)\n",
    "\n",
    "    acc = accuracy(flax.jax_utils.unreplicate(model_state), validation_iterator)\n",
    "    print(f\"Accuracy = {acc}\")"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Raw Cell Format",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
